#!/usr/bin/python

''' Functions to estimate dynamic topics '''

import os
import sys
import random
import operator
import numpy as np
import sklearn.preprocessing
import text.util
import unsupervised.nmf
import unsupervised.rankings
# modules for calculating topic model coherence
from numpy import median
import gensim
from gensim.corpora.dictionary import Dictionary
from gensim.models.coherencemodel import CoherenceModel



class TopicCollection:
    '''
    Accumulates topics from several window topic models, each stored as a
    sparse {term: weight} dict, and builds a combined topic-term matrix.
    '''

    def __init__( self, top_terms = 0, threshold = 1e-6 ):
        # settings
        # if > 0, keep only this many highest-weighted terms per topic;
        # otherwise keep every term whose share of the topic weight reaches
        # `threshold`
        self.top_terms = top_terms
        # minimum fraction of a topic's total weight a term needs (dense mode)
        self.threshold = threshold
        # state
        self.topic_ids = []         # window topic label for each stored topic
        self.all_weights = []       # one {term: weight} dict per stored topic
        self.all_terms = set()      # union of terms over all stored topics

    def add_topic_model( self, H, terms, window_topic_labels ):
        '''
        Add topics from a window topic model to the collection.

        H: topic-term weight array, one row per topic; column j is the
           weight of terms[j].
        terms: the term for each column of H.
        window_topic_labels: one label per row of H, stored in topic_ids.
        '''
        k = H.shape[0]
        for topic_index in range(k):
            row = H[topic_index, :]
            if self.top_terms > 0:
                # sparse representation: the top_terms highest weights,
                # in descending weight order
                selected = np.argsort(row)[::-1][:self.top_terms]
            else:
                # dense representation: terms whose share of the topic's
                # total weight is at least `threshold`
                dense_row = row[:len(terms)]
                total_weight = dense_row.sum()
                if total_weight > 0:
                    selected = np.flatnonzero(dense_row / total_weight >= self.threshold)
                else:
                    # all-zero topic: shares are undefined (0/0), keep no terms
                    selected = []
            topic_weights = {}
            for term_index in selected:
                term = terms[term_index]
                # store the raw (unnormalized) weight from H
                topic_weights[term] = row[term_index]
                self.all_terms.add( term )
            self.all_weights.append( topic_weights )
            self.topic_ids.append( window_topic_labels[topic_index] )

    def create_matrix( self ):
        '''
        Create the topic-term matrix from all window topics that have been added so far.

        Returns (M, all_terms):
        - M has one row per stored topic, in insertion order, and one column
          per term. Rows are scaled to unit L2 length.
        - all_terms lists the term for each column.

        Terms are sorted so the column order (and anything fitted on M) is
        reproducible. Set iteration order of strings varies between runs.
        '''
        all_terms = sorted(self.all_terms)
        term_col_map = {term: col for col, term in enumerate(all_terms)}
        M = np.zeros( (len(self.all_weights), len(all_terms)) )
        # populate the matrix in row-order
        for row, topic_weights in enumerate(self.all_weights):
            for term, weight in topic_weights.items():
                M[row, term_col_map[term]] = weight
        # normalize the matrix rows to L2 unit length
        normalizer = sklearn.preprocessing.Normalizer(norm='l2', copy=True)
        M = normalizer.fit_transform(M)
        return (M, all_terms)

class FitDynamic:
    '''
    Fit a dynamic topic model by applying NMF to the stacked topic-term
    matrix built from all window topic models.
    '''

    def __init__(self, window_topics, maxiter=200, random_seed=1000):
        '''
        window_topics: list of window model dicts. Each dict is read for
            'k' (number of topics in the window) and 'window_results', laid out as
            [doc_ids, terms, term_rankings, partition, W, H, window_topic_labels].
        maxiter: maximum number of NMF iterations.
        random_seed: seed passed to the NMF implementation.
        '''
        self.window_topics = window_topics
        self.maxiter = maxiter
        self.random_seed = random_seed
        # set by fit(); None means no model has been fitted yet
        self.fit_results = None

    def _create_topic_term_matrix(self):
        '''
        Collect every window's topics and build the window-topic x term
        matrix. Sets self.collection, self.M and self.all_terms.
        '''
        self.collection = TopicCollection()

        for window_model in self.window_topics:
            window_results = window_model['window_results']
            H = window_results[5]
            terms = window_results[1]
            window_topic_labels = window_results[6]
            self.collection.add_topic_model(H, terms, window_topic_labels)

        self.M, self.all_terms = self.collection.create_matrix()

    def fit(self, k, verbose=False, keywords=10):
        '''
        Fit a dynamic topic model with k topics.

        k: number of dynamic topics.
        verbose: if True, print the top-ranked terms of each dynamic topic.
        keywords: number of terms per topic printed when verbose is True.

        The results are stored in self.fit_results as
        [window_topic_ids, all_terms, term_rankings, partition, W, H, topic_labels].
        '''
        # Rows of M are window topics, columns are terms
        self._create_topic_term_matrix()

        # NMF implementation
        impl = unsupervised.nmf.SklNMF(max_iters = self.maxiter,
                                       init_strategy = "nndsvd",
                                       random_seed = self.random_seed )
        impl.apply(self.M, k)

        # Partition of the rows of M (window topics) across dynamic topics;
        # presumably disjoint, see unsupervised.nmf.generate_partition
        partition = impl.generate_partition()

        topic_labels = ["D%02d" % (i + 1) for i in range(k)]

        # Terms of each dynamic topic, ordered by impl.rank_terms
        term_rankings = []
        for topic_index in range(k):
            ranked_term_indices = impl.rank_terms(topic_index)
            term_rankings.append([self.all_terms[i] for i in ranked_term_indices])

        if verbose:
            print(unsupervised.rankings.format_term_rankings(term_rankings, top = keywords))

        self.fit_results = [self.collection.topic_ids, self.all_terms, term_rankings, partition,
                            impl.W, impl.H, topic_labels]

    def get_dynamic_topics(self, num_terms = 10):
        ''' Returns the top n dynamic topic keywords from a fitted
            model.

            Inputs
            ------
            num_terms = the number of keywords to return.

            Output
            ------
            Returns a list of dictionaries with topic ids and keywords, where
            keywords are the top terms joined by single spaces.

            Raises
            ------
            ValueError if fit() has not been called.
        '''
        if self.fit_results is None:
            raise ValueError("You need to fit a model prior to viewing top keywords.")

        return [{'topic': i, 'keywords': ' '.join(terms[0:num_terms])}
                for i, terms in enumerate(self.fit_results[2])]

    def _split(self):
        '''
        Return (start, end) row ranges of the dynamic W matrix belonging to
        each window, in window order, using each window's 'k'.
        '''
        start_ = 0
        idx = []
        for window in self.window_topics:
            end_ = start_ + window['k']
            idx.append((start_, end_))
            start_ = end_
        return idx

    @staticmethod
    def _normalize_rows(matrix):
        '''
        Scale each row of matrix to sum to 1. All-zero rows stay zero
        rather than becoming NaN.
        '''
        row_sums = np.sum(matrix, axis=1, keepdims=True)
        return matrix / np.where(row_sums == 0, 1, row_sums)

    def get_document_topics(self):
        ''' Takes a fitted dynamic topic model and returns a
            document-topic matrix. Each document's weights are
            normalized to sum to 1, or are all zero for a document
            with no window topic weight.

            Output
            ------
            The document topics returned as a list of python lists.
            Each list represents a document and has the form:
            [document_id, ...dynamic topic weights....]
            where document_id comes from the window results (element 0)
            and keeps its original type.

            Raises
            ------
            ValueError if fit() has not been called.
        '''
        if self.fit_results is None:
            raise ValueError("You need to fit a model prior to estimating document topics.")

        W_dynamic_all = self.fit_results[4]
        doc_blocks = []
        document_labels = []
        # The dynamic W stacks all windows' topics, so split it into
        # per-window (possibly irregular) row ranges
        for (start, end), window in zip(self._split(), self.window_topics):
            window_res = window['window_results']

            # Window-topic -> dynamic-topic weights, each row summing to 1
            W_dyn = self._normalize_rows(W_dynamic_all[start:end, :])
            # Document -> window-topic weights, each row summing to 1
            W_win = self._normalize_rows(window_res[4])

            # Document -> dynamic-topic weights for this window
            doc_blocks.append(np.dot(W_win, W_dyn))
            document_labels.append(np.asarray(window_res[0]))

        labels = np.hstack(tuple(document_labels)).tolist()
        dtm = np.vstack(tuple(doc_blocks)).tolist()

        # Build rows separately so the ids are not coerced to the weights'
        # dtype, and the weights are not coerced to strings by string ids
        return [[label] + weights for label, weights in zip(labels, dtm)]

class SelectDynamicTopics:
    '''
    Score dynamic topic models over a range of k values with gensim topic
    coherence, to help choose the number of dynamic topics.
    '''

    def __init__(self, time_slice_data, window_topics, d_min, d_max, step):
        '''
        time_slice_data: rows whose element [1] is a document's text as
            space-separated tokens.
        window_topics: window models passed to FitDynamic.
        d_min, d_max, step: candidate k values are range(d_min, d_max, step),
            so d_max itself is not evaluated.
        '''
        self.time_slice_data = time_slice_data
        self.window_topics = window_topics
        self.d_min = d_min
        self.d_max = d_max
        self.step = step

    def _tokenized_texts(self):
        ''' Split each row's text (element [1]) on single spaces for gensim. '''
        return [row[1].split(' ') for row in self.time_slice_data]

    def _coherence_per_k(self, coherence_metric, num_terms, message):
        '''
        Fit a dynamic model for every candidate k and score its topics.

        Yields (k, keywords, per-topic coherence list). Here keywords is the
        output of FitDynamic.get_dynamic_topics(num_terms). `message` is a
        format string that is printed with k before each score is computed.
        '''
        texts_for_dictionary = self._tokenized_texts()
        dictionary = Dictionary(texts_for_dictionary)

        for k in range(self.d_min, self.d_max, self.step):
            fd = FitDynamic(self.window_topics)
            fd.fit(k, verbose = False)
            keywords = fd.get_dynamic_topics(num_terms)
            topics = [row['keywords'].split(' ') for row in keywords]
            co = CoherenceModel(topics=topics,
                                texts=texts_for_dictionary,
                                dictionary=dictionary,
                                coherence=coherence_metric)

            print(message.format(k))
            yield k, keywords, co.get_coherence_per_topic()

    def get_median_coherence(self, coherence_metric = 'c_v', num_terms = 15):
        '''
        Get coherence using the gensim Topic Coherence pipeline.

        Returns a list of {'k', 'median_cv'} dicts, one per candidate k,
        sorted by median per-topic coherence, highest first.
        '''
        cv_results = []
        for k, _, coherence_by_topic in self._coherence_per_k(
                coherence_metric, num_terms,
                "Calculating median coherence for dynamic topics k = {}"):
            cv_results.append({'k': k,
                               'median_cv': median(coherence_by_topic)})

        # Sort prior to returning results
        return sorted(cv_results,
                      key=operator.itemgetter('median_cv'),
                      reverse=True)

    def get_coherence(self, coherence_metric = 'c_v', num_terms = 15):
        '''
        Get coherence using the gensim Topic Coherence pipeline.

        Returns a list of {'k', 'k_cv', 'keywords'} dicts in k order, where
        'k_cv' is the per-topic coherence list and 'keywords' is the dynamic
        topic keyword list that was scored.
        '''
        cv_results = []
        for k, keywords, coherence_by_topic in self._coherence_per_k(
                coherence_metric, num_terms,
                "Calculating coherence for dynamic topics k = {}"):
            cv_results.append({'k': k,
                               'k_cv': coherence_by_topic,
                               'keywords': keywords})
        return cv_results

